import os

import networkx as nx
from pgmpy.base.DAG import PDAG
from pgmpy.models.BayesianModel import BayesianNetwork
import numpy as np

def save_causal_graph(graph: nx.DiGraph, save_path: str, filename: str, save_gml: bool = True, save_npz: bool = True) -> None:
    """
    Save a networkx graph to disk as a GraphML file and/or an adjacency-matrix file.

    Parameters:
    graph (nx.DiGraph): The networkx graph to be saved.
    save_path (str): Directory in which the files are written.
    filename (str): Base name (without extension) of the output files.
    save_gml (bool): If True, write ``<filename>.gml``. NOTE: the content is
        GraphML (written with ``nx.write_graphml``), not GML, so it must be
        read back with ``nx.read_graphml``.
    save_npz (bool): If True, write ``<filename>.npz`` containing the dense
        adjacency matrix from ``nx.to_numpy_array`` (rows/columns follow the
        order of ``graph.nodes()``; node labels are not stored). NOTE: despite
        the extension, this is comma-delimited text written by ``np.savetxt``,
        readable with ``np.loadtxt(path, delimiter=',')`` rather than ``np.load``.
    """
    saved_files = []

    if save_gml:
        # GraphML content under a ".gml" name, kept as-is for existing readers.
        nx.write_graphml(
            graph,
            os.path.join(save_path, f"{filename}.gml"),
        )
        saved_files.append(f"{filename}.gml")

    if save_npz:
        # Dense adjacency matrix, written as comma-separated text (see docstring).
        adj_matrix = nx.to_numpy_array(graph)
        np.savetxt(os.path.join(save_path, f"{filename}.npz"), adj_matrix, delimiter=',')
        saved_files.append(f"{filename}.npz")

    n_nodes = graph.number_of_nodes()
    n_edges = graph.number_of_edges()
    if saved_files:
        # Report only the files that were actually written.
        print(f"Graph with {n_nodes} nodes and {n_edges} edges has been saved as {' and '.join(saved_files)} in {save_path}.")
    else:
        print(f"Graph with {n_nodes} nodes and {n_edges} edges was not saved (save_gml and save_npz are both False).")
    return

def nx2gml(graph: nx.Graph, save_path: str, filename: str) -> nx.Graph:
    """
    Write a networkx graph, stripped of node attributes, to disk and load it back.

    Parameters:
    graph (nx.Graph): The networkx graph to be written. It is not modified;
        node attributes are removed from a copy only.
    save_path (str): Directory in which the file is written.
    filename (str): Base name (without extension) of the output file.

    Returns:
    nx.Graph: The graph re-read from ``<filename>.gml``. NOTE: the file content
        is GraphML (``nx.write_graphml``) despite the ".gml" extension, so it is
        read back with ``nx.read_graphml``. With that reader's defaults, node
        identifiers come back as strings.
    """
    # Strip node attributes on a copy so the caller's graph keeps its data.
    # NOTE(review): presumably done because some node attribute values are not
    # GraphML-serializable — confirm.
    stripped = graph.copy()
    for node in stripped.nodes:
        stripped.nodes[node].clear()

    out_path = os.path.join(save_path, f"{filename}.gml")
    nx.write_graphml(stripped, out_path)

    # Use the reader that matches the writer: nx.read_gml cannot parse the
    # GraphML written above.
    return nx.read_graphml(out_path)


def get_dag_from_causal_graph(causal_graph: nx.DiGraph) -> BayesianNetwork:
    """Convert a (possibly partially directed) causal graph into a DAG.

    A pair of opposite edges ``u -> v`` and ``v -> u`` is treated as one
    undirected edge ``u - v``; every other edge is kept as directed. The
    resulting PDAG is then oriented into a DAG with pgmpy's ``PDAG.to_dag()``.

    Args:
        causal_graph (nx.DiGraph): Causal graph. It can be partially directed,
            especially after performing a latent projection. Undirected edges
            are expected to be encoded as bidirectional edge pairs.

    Returns:
        BayesianNetwork: The object returned by ``PDAG.to_dag()`` (a pgmpy
        ``DAG`` per the pgmpy documentation; the annotation is kept unchanged
        for compatibility). It contains every node of ``causal_graph``,
        including isolated ones.
    """
    # Set of edges for O(1) reverse-edge lookups (the list version was O(E) per edge).
    edge_set = set(causal_graph.edges())

    directed_edges = []
    undirected_edges = []
    # Undirected edges already recorded, for O(1) duplicate checks.
    recorded_undirected = set()

    for u, v in causal_graph.edges():
        if (v, u) not in edge_set:
            directed_edges.append((u, v))
        elif (v, u) not in recorded_undirected:
            # Record each opposite-edge pair once, in the orientation seen first.
            undirected_edges.append((u, v))
            recorded_undirected.add((u, v))

    pdag = PDAG(directed_ebunch=directed_edges, undirected_ebunch=undirected_edges)
    # Add all nodes so that nodes without any edge are kept in the result.
    pdag.add_nodes_from(causal_graph.nodes())
    return pdag.to_dag()
